import torch
import torch.nn as nn
import torch.nn.functional as F


class SMTWTModel(nn.Module):
    """Step-wise job-selection policy for SMTWT scheduling instances.

    SMTWT presumably stands for single-machine total weighted tardiness
    (TODO confirm). On every ``forward`` call the policy:

    1. shifts the due dates by the processing time already consumed,
    2. removes the jobs that were already selected,
    3. encodes/decodes the remaining jobs, conditioned on a per-rollout latent
       vector, and
    4. picks one more job per (batch, pomo) rollout.

    ``model_params`` keys read directly here: ``pomo_size``,
    ``eval_pomo_size`` and ``eval_type``. The encoder and decoder read more.
    """

    def __init__(self, **model_params):
        super().__init__()
        self.model_params = model_params
        self.encoder = SMTWT_Encoder(**model_params)
        self.decoder = SMTWT_Decoder(**model_params)
        self.encoded_nodes = None  # not used in this class; kept for compatibility
        # Populated by pre_forward(); declared here so the attributes always exist.
        self.problems = None
        self.latent_var = None

    def pre_forward(self, reset_state, latent_var):
        """Cache the problem instances and latent vectors for one rollout.

        Args:
            reset_state: object exposing ``problems`` of shape (batch, job, feature).
            latent_var: latent conditioning vectors, passed unchanged to the
                decoder (which reshapes them from (batch, pomo, latent_dim)).
        """
        pomo_key = 'pomo_size' if self.training else 'eval_pomo_size'
        pomo_size = self.model_params[pomo_key]

        problems = reset_state.problems  # (batch, job, feature)
        # repeat_interleave keeps the pomo copies of one instance adjacent:
        # row b * pomo + p belongs to batch element b.
        self.problems = problems.repeat_interleave(pomo_size, dim=0)  # (batch*pomo, job, feature)
        self.latent_var = latent_var

    def _get_new_data(self, problem, selected):
        """Drop the already-selected jobs from ``problem``.

        Args:
            problem: (BP, N, n_feat) tensor, BP = batch * pomo.
            selected: indices of the already-selected jobs (any integer or
                integral-valued float dtype), reshapeable to (BP, step).
                Indices are assumed distinct within each row.

        Returns:
            Tuple ``(new_data, unselected_idx)``. ``new_data`` holds the
            (BP, N - step, n_feat) features of the remaining jobs, in their
            original order. ``unselected_idx`` holds the matching
            (BP, N - step) original job indices.
        """
        BP, N, n_feat = problem.shape
        step = selected.size(-1)

        # scatter_ requires an int64 index (the same cast _problem_update does).
        selected = selected.reshape(BP, step).long()
        mask = torch.zeros(BP, N, dtype=torch.bool, device=problem.device)
        mask.scatter_(1, selected, True)

        unselected_mask = ~mask
        new_data = problem[unselected_mask].view(BP, N - step, n_feat)
        indices = torch.arange(N, device=problem.device).unsqueeze(0).expand(BP, N)
        unselected_idx = indices[unselected_mask].view(BP, N - step)
        return new_data, unselected_idx

    def _problem_update(self, problem, selected_node_list):
        """Subtract the consumed processing time from every due date.

        Args:
            problem: (batch*pomo, job, feature) tensor, batch-major as built by
                ``pre_forward``. Feature 0 is the processing time and feature 1
                is the due date.
            selected_node_list: (batch, pomo, step) indices of the jobs
                scheduled so far.

        Returns:
            A new (batch*pomo, job, feature) tensor. Feature 1 is reduced by
            the total processing time of each rollout's selected jobs; the
            input is left untouched.
        """
        bs, ps, _ = selected_node_list.shape
        js = problem.size(1)
        problem = problem.reshape(bs, ps, js, -1)

        processing_time = problem[..., 0]
        consumed = torch.gather(processing_time, dim=2, index=selected_node_list.long()).sum(dim=2)

        updated = problem.clone()
        updated[..., 1] = problem[..., 1] - consumed.unsqueeze(-1)
        return updated.reshape(bs * ps, js, -1)

    def forward(self, state, selected_node_list):
        """Select the next job for every rollout.

        Args:
            state: object whose ``BATCH_IDX`` has shape (batch, pomo); only its
                size is used.
            selected_node_list: (batch, pomo, step) jobs selected so far.

        Returns:
            ``(selected, prob)``, both of shape (batch, pomo). ``selected``
            holds the original job indices. ``prob`` holds the policy
            probability of each choice and stays differentiable.
        """
        batch_size = state.BATCH_IDX.size(0)
        pomo_size = state.BATCH_IDX.size(1)

        updated_problem = self._problem_update(self.problems, selected_node_list)
        new_data, clean_idx = self._get_new_data(updated_problem, selected_node_list)
        all_job_probs = self.decoder(self.encoder(new_data), self.latent_var)  # (BP, N - step)

        sample = self.training or self.model_params['eval_type'] == 'softmax'
        with torch.no_grad():
            if sample:
                selected_rel = torch.multinomial(all_job_probs, num_samples=1).squeeze(dim=1)
            else:  # greedy
                selected_rel = torch.argmax(all_job_probs, dim=1)
            selected_rel = selected_rel.unsqueeze(1)  # (BP, 1), index into remaining jobs
            selected = clean_idx.gather(1, selected_rel).squeeze(1).reshape(batch_size, pomo_size)
        # Gathered outside no_grad so gradients flow through the chosen probability.
        prob = all_job_probs.gather(1, selected_rel).squeeze(1).reshape(batch_size, pomo_size)

        return selected, prob

########################################
# ENCODER
########################################
class SMTWT_Encoder(nn.Module):
    """Embed each job independently with one affine layer.

    Input jobs carry exactly 3 raw features; the output width is
    ``model_params['embedding_dim']``.
    """

    def __init__(self, **model_params):
        super().__init__()
        self.model_params = model_params
        # The attribute name `embedding` is a state_dict key; keep it stable.
        self.embedding = nn.Linear(
            in_features=3,
            out_features=model_params['embedding_dim'],
            bias=True,
        )

    def forward(self, data):
        """Map (..., 3) job features to (..., embedding_dim) embeddings."""
        return self.embedding(data)


class SMTWT_Decoder(nn.Module):
    """Attention stack followed by a latent-conditioned scoring head.

    Runs ``decoder_layer_num`` DecoderLayer blocks over the job embeddings.
    It then appends each rollout's latent vector to every one of its jobs,
    maps each job to a scalar logit and softmaxes over the jobs.
    """

    def __init__(self, **model_params):
        super().__init__()
        self.model_params = model_params
        embedding_dim = model_params['embedding_dim']
        decoder_layer_num = model_params['decoder_layer_num']
        latent_dim = model_params['latent_cont_size'] + model_params['latent_disc_size']

        # `layers` and `Linear_final` are state_dict keys; keep the names.
        self.layers = nn.ModuleList([DecoderLayer(**model_params) for _ in range(decoder_layer_num)])
        self.Linear_final = nn.Linear(embedding_dim + latent_dim, 1, bias=True)

    def forward(self, data, latent_var):
        """Return job-selection probabilities.

        Args:
            data: (batch*pomo, n_jobs, embedding_dim) job embeddings.
            latent_var: (batch, pomo, latent_dim) latent vectors, with
                latent_dim = latent_cont_size + latent_disc_size. Row-major
                flattening to batch*pomo must match the row order of ``data``.

        Returns:
            (batch*pomo, n_jobs) probabilities that sum to 1 over jobs.
        """
        out = data
        for layer in self.layers:
            out = layer(out)

        bs, ps, latent_dim = latent_var.shape
        n_jobs = out.size(1)
        # Broadcast each rollout's latent vector to all of its jobs.
        latent = latent_var.reshape(bs * ps, 1, latent_dim).expand(bs * ps, n_jobs, latent_dim)
        mlp_input = torch.cat([out, latent], dim=-1)

        logits = self.Linear_final(mlp_input).squeeze(-1)  # (batch*pomo, n_jobs)
        return F.softmax(logits, dim=-1)

class DecoderLayer(nn.Module):
    """Self-attention block with residual connections and no normalisation.

    ``x -> x + MHA(x) -> h + FF(h)``.
    """

    def __init__(self, **model_params):
        super().__init__()
        self.model_params = model_params
        emb = model_params['embedding_dim']
        proj_dim = model_params['head_num'] * model_params['qkv_dim']

        # Creation order and attribute names match the checkpoint layout.
        self.Wq = nn.Linear(emb, proj_dim, bias=False)
        self.Wk = nn.Linear(emb, proj_dim, bias=False)
        self.Wv = nn.Linear(emb, proj_dim, bias=False)
        self.multi_head_combine = nn.Linear(proj_dim, emb)

        self.feedForward = Feed_Forward_Module(**model_params)

    def forward(self, input1):
        """Apply attention and feed-forward, each with a residual skip."""
        heads = self.model_params['head_num']
        q, k, v = (
            reshape_by_heads(proj(input1), head_num=heads)
            for proj in (self.Wq, self.Wk, self.Wv)
        )
        attended = self.multi_head_combine(multi_head_attention(q, k, v))

        hidden = input1 + attended
        return hidden + self.feedForward(hidden)

def reshape_by_heads(qkv, head_num):
    """Split the trailing axis into attention heads.

    Shapes: (B, n, head_num * d) -> (B, head_num, n, d).
    """
    per_head = qkv.reshape(*qkv.shape[:2], head_num, -1)  # (B, n, head_num, d)
    return per_head.transpose(1, 2)

def multi_head_attention(q, k, v):
    """Scaled dot-product attention per head, with the heads concatenated.

    Args:
        q: (B, head_num, n, key_dim) queries.
        k: (B, head_num, m, key_dim) keys.
        v: (B, head_num, m, val_dim) values; val_dim may differ from key_dim.

    Returns:
        (B, n, head_num * val_dim) tensor.
    """
    batch_s, _, n, key_dim = q.shape

    # Python-float scale: no per-call tensor allocation and no CPU tensor
    # mixed into (possibly) GPU math.
    score = torch.matmul(q, k.transpose(2, 3)) / key_dim ** 0.5  # (B, head_num, n, m)
    weights = F.softmax(score, dim=3)
    out = torch.matmul(weights, v)  # (B, head_num, n, val_dim)

    # Flatten the heads using v's own width, so val_dim != key_dim also works.
    return out.transpose(1, 2).reshape(batch_s, n, -1)

class Feed_Forward_Module(nn.Module):
    """Position-wise two-layer MLP: Linear -> ReLU -> Linear.

    Width goes embedding_dim -> ff_hidden_dim -> embedding_dim.
    """

    def __init__(self, **model_params):
        super().__init__()
        width = model_params['embedding_dim']
        hidden = model_params['ff_hidden_dim']
        # W1/W2 are state_dict keys; keep names and creation order.
        self.W1 = nn.Linear(width, hidden)
        self.W2 = nn.Linear(hidden, width)

    def forward(self, input1):
        """Apply the MLP independently at every position of ``input1``."""
        activated = F.relu(self.W1(input1))
        return self.W2(activated)